import numpy as np
import time
from casadi import *
from Polar import *
from Sector_Parameters import *
from Pre_Jacobian import *

# Module-level dimensions/parameters returned by `circular_parameters()`
# (brought into scope by one of the star imports above). Within this section
# only `Nt` (forwarded to F1/F2) and `Nx` (used to size the zero block in
# `ode_sensitivity`; presumably the state dimension) are referenced.
(
    Nr, Nt, Nz,
    dr, dt, dz,
    Np, NN, Nx, Nw, Ny, Nv, N_p,
    Nx_aug,
) = circular_parameters()

# Right-hand side of the state sensitivity equation (CasADi version).
def state_sensitivity(state_sen, jac_state, jac_pars):
    """Return the sensitivity derivative ``jac_state * state_sen + jac_pars``.

    The matrix product is formed with ``mtimes`` (CasADi's matrix multiply,
    available through the ``casadi`` star import), so the inputs may be
    CasADi symbolic or numeric matrices.

    Args:
        state_sen: Current sensitivity matrix S.
        jac_state: Jacobian of the dynamics with respect to the state.
        jac_pars: Additive forcing term (Jacobian with respect to the
            quantities the sensitivity is taken against).

    Returns:
        dS/dt evaluated at the given matrices.
    """
    drift = mtimes(jac_state, state_sen)
    dhdt = drift + jac_pars
    return dhdt

def state_sensitivity_np(state_sen, jac_state, jac_pars):
    """NumPy evaluation of the sensitivity derivative.

    Computes ``jac_state @ state_sen + jac_pars`` with ``np.matmul``.

    Args:
        state_sen: Current sensitivity matrix S.
        jac_state: Jacobian of the dynamics with respect to the state.
        jac_pars: Additive forcing term, broadcast-added to the product.

    Returns:
        dS/dt as produced by ``np.matmul`` plus ``jac_pars``.
    """
    product = np.matmul(jac_state, state_sen)
    return product + jac_pars

# Integrate the state-sensitivity equation forward with fixed-step explicit Euler.
def state_sensitivity_approx(state_sen, jac_state, jac_pars, *, n_steps=12, step_size=60):
    """Advance the state sensitivity matrix with fixed-step explicit Euler.

    The Jacobians are held constant over all steps, so each step performs::

        S <- S + step_size * (jac_state @ S + jac_pars)

    via ``state_sensitivity_np``. The total integrated span is
    ``n_steps * step_size`` (12 * 60 with the defaults).

    Args:
        state_sen: Sensitivity matrix S at the start of the interval.
        jac_state: Jacobian of the dynamics with respect to the state.
        jac_pars: Forcing term added at every step.
        n_steps: Number of Euler steps. Defaults to 12, the previously
            hard-coded count. ``0`` returns ``state_sen`` unchanged.
        step_size: Length of each Euler step. Defaults to 60; the original
            comment called this a "1 min sampling time", i.e. presumably
            seconds — confirm against the model's time units.

    Returns:
        The sensitivity matrix after ``n_steps`` Euler steps.
    """
    sen = state_sen
    for _ in range(n_steps):
        sen = sen + step_size * state_sensitivity_np(sen, jac_state, jac_pars)
    return sen


def ode_sensitivity(x_pre, p_pre, u_pre, ww_pre, i, kc, et, x_sensitivity):
    """Propagate the state sensitivity matrix over one integration interval.

    Evaluates ``F1`` (used as the state Jacobian df/dx) and ``F2`` (used as
    df/dp) at the given operating point. It then builds the forcing term
    ``[0 | df/dp]``; the leading block is an (Nx, Nx) zero because the
    dynamics do not depend explicitly on the initial state. Finally it
    integrates the sensitivity ODE with ``state_sensitivity_approx``.

    Args:
        x_pre, p_pre, u_pre, ww_pre, i, kc, et: Forwarded unchanged to
            ``F1``/``F2`` (together with the module-level ``Nt``); their
            meaning is defined by those functions.
        x_sensitivity: Current sensitivity matrix. It presumably has as many
            columns as the forcing term (Nx + number of parameters) — TODO
            confirm.

    Returns:
        The propagated sensitivity matrix.

    Side effects:
        Prints the elapsed time spent in this call.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and is not meant for measuring intervals.
    t_start = time.perf_counter()

    dfdx_matrix = F1(x_pre, p_pre, u_pre, ww_pre, i, Nt, kc, et)
    dfdp_matrix = F2(x_pre, p_pre, u_pre, ww_pre, i, Nt, kc, et)

    # The state equation does not explicitly depend on the initial state,
    # so the derivative with respect to x0 is an (Nx, Nx) zero block.
    dfdx_init = np.zeros((Nx, Nx))
    # Forcing term with respect to the stacked vector [x0, p].
    dfdx_act = np.concatenate((dfdx_init, dfdp_matrix), axis=1)

    xk_new = state_sensitivity_approx(x_sensitivity, dfdx_matrix, dfdx_act)

    elapsed = time.perf_counter() - t_start
    print('Sensitivity ODE takes', elapsed, ' seconds to compute')
    return xk_new
